{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用DPO算法微调模型\n",
    "\n",
    "本教程演示如何使用DPO算法微调大模型（以Llama-3.1-8B模型为例）。通过本教程，你将学习如何配置训练参数，并使用 DPO 算法在具有偏好标签的数据上进行强化学习式的训练，从而提升模型在对齐任务中的性能。\n",
    "\n",
    "## 1. 什么是 DPO 算法？\n",
    "\n",
    "DPO（Direct Preference Optimization）是一种用于训练语言模型更好地对齐人类偏好的方法。它不依赖显式的奖励模型或策略梯度方法，而是直接在“人类偏好数据”上优化模型，使其在给定两个回答中更倾向于人类偏好的那个。\n",
    "\n",
    "## 2. 环境配置\n",
    "\n",
    "在开始之前，请确保您已安装 ``align-anything`` 包。\n",
    "\n",
    "```bash\n",
    "# 克隆仓库\n",
    "git clone git@github.com:PKU-Alignment/align-anything.git\n",
    "cd align-anything\n",
    "\n",
    "# 使用conda创建虚拟环境\n",
    "conda create -n align-anything python==3.11\n",
    "conda activate align-anything\n",
    "```\n",
    "\n",
    "- **`[Optional]`** We recommend installing [CUDA](https://anaconda.org/nvidia/cuda) in the conda environment and set the environment variable.\n",
    "\n",
    "```bash\n",
    "# 我们在 H800 计算集群上测试过，这个版本的 CUDA 效果很好。\n",
    "# 您可以根据计算集群的实际情况调整此版本。\n",
    "\n",
    "conda install nvidia/label/cuda-12.2.0::cuda\n",
    "export CUDA_HOME=$CONDA_PREFIX\n",
    "```\n",
    "\n",
    "> 如果您的 CUDA 安装在不同的位置，例如 `/usr/local/cuda/bin/nvcc`，您可以按如下方式设置环境变量：\n",
    "\n",
    "```bash\n",
    "export CUDA_HOME=\"/usr/local/cuda\"\n",
    "```\n",
    "\n",
    "最后，通过以下命令安装 `align-anything`：\n",
    "\n",
    "```bash\n",
    "# 我们为训练和评估准备了快速安装。\n",
    "# 如果您只需要使用训练或评估模块，\n",
    "# 您可以安装相应的依赖项。\n",
    "pip install -e .[train] # 安装训练依赖项\n",
    "pip install -e .[evaluate] # 安装评估依赖项\n",
    "\n",
    "# 如果您需要安装所有依赖项，可以使用以下命令：\n",
    "pip install -e .[all]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Llama-3.1-8B-Instruct模型输出示例\n",
    "下面，让我们首先测试Llama-3.1-8B-Instruct模型的zero-shot能力。\n",
    "### 3.1 导入所需的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import os\n",
    "import torch\n",
    "\n",
    "os.environ[\"TRANSFORMERS_OFFLINE\"] = \"1\"\n",
    "os.environ[\"HF_DATASETS_OFFLINE\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 加载原始的Llama 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128256, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda\"  # 将device设置为\"cuda\"以使用GPU\n",
    "model_path = \"/PATH/TO/YOUR/Llama-3.1-8B-Instruct\"  # 请更换为实际的模型路径\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n",
    "\n",
    "# 将模型设置为eval模式\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 测试原始模型的性能\n",
    "\n",
    "让我们用一个示例问题测试 Llama-3.1-8B-Instruct 模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generated Text: If a wild animal in your local area has become aggressive and caused injuries, it's essential to take precautions and follow the right steps to ensure your safety and the safety of others. Here's a step-by-step guide:\n",
      "\n",
      "1.  **Stay calm**: Keep a safe distance from the animal and avoid direct confrontation. Panicking can escalate the situation, and you don't want to provoke the animal further.\n",
      "\n",
      "2.  **Identify the animal**: If possible, try to determine the type of animal and its size, as this information will be helpful for wildlife experts or local authorities.\n",
      "\n",
      "3.  **Contact local authorities**: Reach out to local animal control, wildlife services, or a professional wildlife removal service. They will send trained experts to handle the situation.\n",
      "\n",
      "4.  **Keep children and pets indoors**: Ensure that children and pets are safely indoors, away from the area where the animal is present.\n",
      "\n",
      "5.  **Do not approach or feed the animal**: Feeding or approaching the animal can make it more aggressive and increase the risk of injury.\n",
      "\n",
      "6.  **Secure your home**: If the animal is in your yard or near your home, consider securing any food sources, such as pet food, bird seed, or trash, which may be attracting the animal.\n",
      "\n",
      "7.  **Keep a safe distance**: If you're in the same area as the animal, maintain a safe distance and avoid direct eye contact.\n",
      "\n",
      "8.  **Follow instructions from authorities**: If local authorities or wildlife experts arrive, follow their instructions carefully. They may need you to stay away from the area or provide additional information to help them safely remove the animal.\n",
      "\n",
      "9.  **Report any injuries**: If you or someone else has been injured by the animal, seek medical attention immediately and report the incident to local authorities.\n",
      "\n",
      "10. **Prevent future encounters**: If the animal is a recurring problem, consider taking steps to prevent future encounters, such as securing your home, removing attractants, and installing wildlife-proof fencing.\n",
      "\n",
      "Remember, it's always best to prioritize your safety and the safety of others when dealing with aggressive wild animals.\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?\",\n",
    "    },\n",
    "]\n",
    "\n",
    "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "# the model generate new tokens\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=2048)\n",
    "# convert the generated tokens to text\n",
    "generated_text = tokenizer.decode(\n",
    "    output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(\"\\nGenerated Text:\", generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由此可见，llama 3.1的回答虽然内容详细，但存在信息冗长且关键风险未充分强调的问题。\n",
    "\n",
    "例如，它建议“识别动物”而没有明确提醒远离危险区域，可能会误导人靠近观察，从而增加受伤风险，不利于突发情况下的紧急安全反应。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 使用DPO算法对齐模型\n",
    "\n",
    "**注意**：如果您无法访问huggingface.co，请将huggingface的endpoint设置为hf-mirror.com。您可以进行以下操作：\n",
    "\n",
    "`export HF_ENDPOINT=\"https://hf-mirror.com\"`\n",
    "\n",
    "在这里，我们以 PKU-SafeRLHF 系列数据集为例。PKU-SafeRLHF 数据集是一个注重安全对齐的偏好数据集。该数据集中的每条数据都包含对同一个问题的两个回答，以及这两个回答对应的安全元标签和偏好标注。\n",
    "\n",
    "可以参考如下的训练脚本：\n",
    "\n",
    "```bash\n",
    "MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # model path\n",
    "\n",
    "TRAIN_DATASETS=\"PKU-Alignment/PKU-SafeRLHF-single-dimension\" # dataset path\n",
    "TRAIN_TEMPLATE=\"PKUSafeRLHF\" # dataset template\n",
    "TRAIN_SPLIT=\"train\" # split the dataset\n",
    "\n",
    "OUTPUT_DIR=\"../outputs/llama_dpo\" # output dir\n",
    "\n",
    "# For wandb online logging\n",
    "export WANDB_API_KEY=\"YOUR_API_KEY\"\n",
    "\n",
    "# Source the setup script\n",
    "source ./setup.sh\n",
    "\n",
    "# Execute deepspeed command\n",
    "deepspeed \\\n",
    "     --master_port ${MASTER_PORT} \\\n",
    "     --module align_anything.trainers.text_to_text.dpo \\\n",
    "     --model_name_or_path ${MODEL_NAME_OR_PATH} \\\n",
    "     --train_template ${TRAIN_TEMPLATE} \\\n",
    "     --train_datasets ${TRAIN_DATASETS} \\\n",
    "     --train_split ${TRAIN_SPLIT} \\\n",
    "     --output_dir ${OUTPUT_DIR}\n",
    "\n",
    "```\n",
    "\n",
    "训练完成后，您可以在`OUTPUT_DIR`下找到训练的模型权重。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 测试DPO训练后的模型性能\n",
    "\n",
    "在训练结束后，我们试图测试训练后的模型对齐情况是否有所改观。\n",
    "\n",
    "### 5.1 加载新的模型权重\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128257, 4096, padding_idx=128256)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=128257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_path = \"/PATH/TO/YOUR/TRAINED_MODEL\"  # 请更换为实际的模型路径\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n",
    "\n",
    "# 将模型设置为eval模式\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 测试新模型的性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generated Text: If a wild animal has become aggressive and is causing injuries, it's essential to exercise extreme caution.  First, make sure all people and pets are kept away from the area.  Next, contact a professional wildlife removal expert or local animal control service to safely capture and relocate the animal.  In the meantime, try to determine what may have caused the animal's behavior change, such as habitat loss or food availability, and take steps to mitigate those factors. Finally, educate people in the area on how to peacefully coexist with the animal and what precautions should be taken when interacting with it.\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?\",\n",
    "    },\n",
    "]\n",
    "\n",
    "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "# the model generate new tokens\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=2048)\n",
    "# convert the generated tokens to text\n",
    "generated_text = tokenizer.decode(\n",
    "    output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(\"\\nGenerated Text:\", generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由此可见，训练后的模型回答更加简洁且聚焦在关键安全措施上。\n",
    "\n",
    "如“远离”“联系专业人士”“分析原因”和“公众教育”，体现出以人为本、减少直接接触的风险预防导向，更符合安全对齐的原则。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 致谢\n",
    "\n",
    "- [Hugging Face Transformers 文档](https://huggingface.co/docs/transformers/index)\n",
    "- [DPO 论文](https://arxiv.org/abs/2305.18290)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jy-align",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
